{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "This is a companion notebook for the book [Deep Learning with Python, Third Edition](https://www.manning.com/books/deep-learning-with-python-third-edition). For readability, it only contains runnable code blocks and section titles, and omits everything else in the book: text paragraphs, figures, and pseudocode.\n\n**If you want to be able to follow what's going on, I recommend reading the notebook side by side with your copy of the book.**\n\nThe book's contents are available online at [deeplearningwithpython.io](https://deeplearningwithpython.io)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "!pip install keras keras-hub --upgrade -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# @title\n",
    "import os\n",
    "from IPython.core.magic import register_cell_magic\n",
    "\n",
    "@register_cell_magic\n",
    "def backend(line, cell):\n",
    "    current, required = os.environ.get(\"KERAS_BACKEND\", \"\"), line.split()[-1]\n",
    "    if current == required:\n",
    "        get_ipython().run_cell(cell)\n",
    "    else:\n",
    "        print(\n",
    "            f\"This cell requires the {required} backend. To run it, change KERAS_BACKEND to \"\n",
    "            f\"\\\"{required}\\\" at the top of the notebook, restart the runtime, and rerun the notebook.\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction to TensorFlow, PyTorch, JAX, and Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### A brief history of deep learning frameworks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### How these frameworks relate to each other"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Introduction to TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### First steps with TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Tensors and variables in TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Constant tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "tf.ones(shape=(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "tf.zeros(shape=(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "tf.constant([1, 2, 3], dtype=\"float32\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Random tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x = tf.random.normal(shape=(3, 1), mean=0., stddev=1.)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x = tf.random.uniform(shape=(3, 1), minval=0., maxval=1.)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Tensor assignment and the Variable class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x = np.ones(shape=(2, 2))\n",
    "x[0, 0] = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "v = tf.Variable(initial_value=tf.random.normal(shape=(3, 1)))\n",
    "print(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "v.assign(tf.ones((3, 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "v[0, 0].assign(3.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "v.assign_add(tf.ones((3, 1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Tensor operations: Doing math in TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "a = tf.ones((2, 2))\n",
    "b = tf.square(a)\n",
    "c = tf.sqrt(a)\n",
    "d = b + c\n",
    "e = tf.matmul(a, b)\n",
    "f = tf.concat((a, b), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def dense(inputs, W, b):\n",
    "    return tf.nn.relu(tf.matmul(inputs, W) + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Gradients in TensorFlow: A second look at the GradientTape API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "input_var = tf.Variable(initial_value=3.0)\n",
    "with tf.GradientTape() as tape:\n",
    "    result = tf.square(input_var)\n",
    "gradient = tape.gradient(result, input_var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "input_const = tf.constant(3.0)\n",
    "with tf.GradientTape() as tape:\n",
    "    tape.watch(input_const)\n",
    "    result = tf.square(input_const)\n",
    "gradient = tape.gradient(result, input_const)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "time = tf.Variable(0.0)\n",
    "with tf.GradientTape() as outer_tape:\n",
    "    with tf.GradientTape() as inner_tape:\n",
    "        position = 4.9 * time**2\n",
    "    speed = inner_tape.gradient(position, time)\n",
    "acceleration = outer_tape.gradient(speed, time)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Making TensorFlow functions fast using compilation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def dense(inputs, W, b):\n",
    "    return tf.nn.relu(tf.matmul(inputs, W) + b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "@tf.function(jit_compile=True)\n",
    "def dense(inputs, W, b):\n",
    "    return tf.nn.relu(tf.matmul(inputs, W) + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### An end-to-end example: A linear classifier in pure TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "num_samples_per_class = 1000\n",
    "negative_samples = np.random.multivariate_normal(\n",
    "    mean=[0, 3], cov=[[1, 0.5], [0.5, 1]], size=num_samples_per_class\n",
    ")\n",
    "positive_samples = np.random.multivariate_normal(\n",
    "    mean=[3, 0], cov=[[1, 0.5], [0.5, 1]], size=num_samples_per_class\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "inputs = np.vstack((negative_samples, positive_samples)).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "targets = np.vstack(\n",
    "    (\n",
    "        np.zeros((num_samples_per_class, 1), dtype=\"float32\"),\n",
    "        np.ones((num_samples_per_class, 1), dtype=\"float32\"),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.scatter(inputs[:, 0], inputs[:, 1], c=targets[:, 0])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "input_dim = 2\n",
    "output_dim = 1\n",
    "W = tf.Variable(initial_value=tf.random.uniform(shape=(input_dim, output_dim)))\n",
    "b = tf.Variable(initial_value=tf.zeros(shape=(output_dim,)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def model(inputs, W, b):\n",
    "    return tf.matmul(inputs, W) + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def mean_squared_error(targets, predictions):\n",
    "    per_sample_losses = tf.square(targets - predictions)\n",
    "    return tf.reduce_mean(per_sample_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "learning_rate = 0.1\n",
    "\n",
    "@tf.function(jit_compile=True)\n",
    "def training_step(inputs, targets, W, b):\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = model(inputs, W, b)\n",
    "        loss = mean_squared_error(predictions, targets)\n",
    "    grad_loss_wrt_W, grad_loss_wrt_b = tape.gradient(loss, [W, b])\n",
    "    W.assign_sub(grad_loss_wrt_W * learning_rate)\n",
    "    b.assign_sub(grad_loss_wrt_b * learning_rate)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "for step in range(40):\n",
    "    loss = training_step(inputs, targets, W, b)\n",
    "    print(f\"Loss at step {step}: {loss:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "predictions = model(inputs, W, b)\n",
    "plt.scatter(inputs[:, 0], inputs[:, 1], c=predictions[:, 0] > 0.5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x = np.linspace(-1, 4, 100)\n",
    "y = -W[0] / W[1] * x + (0.5 - b) / W[1]\n",
    "plt.plot(x, y, \"-r\")\n",
    "plt.scatter(inputs[:, 0], inputs[:, 1], c=predictions[:, 0] > 0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### What makes the TensorFlow approach unique"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Introduction to PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### First steps with PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Tensors and parameters in PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Constant tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.ones(size=(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "torch.zeros(size=(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "torch.tensor([1, 2, 3], dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Random tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "torch.normal(\n",
    "mean=torch.zeros(size=(3, 1)),\n",
    "std=torch.ones(size=(3, 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "torch.rand(3, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Tensor assignment and the Parameter class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x = torch.zeros(size=(2, 1))\n",
    "x[0, 0] = 1.\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x = torch.zeros(size=(2, 1))\n",
    "p = torch.nn.parameter.Parameter(data=x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Tensor operations: Doing math in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "a = torch.ones((2, 2))\n",
    "b = torch.square(a)\n",
    "c = torch.sqrt(a)\n",
    "d = b + c\n",
    "e = torch.matmul(a, b)\n",
    "f = torch.cat((a, b), dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def dense(inputs, W, b):\n",
    "    return torch.nn.relu(torch.matmul(inputs, W) + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Computing gradients with PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "input_var = torch.tensor(3.0, requires_grad=True)\n",
    "result = torch.square(input_var)\n",
    "result.backward()\n",
    "gradient = input_var.grad\n",
    "gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "result = torch.square(input_var)\n",
    "result.backward()\n",
    "input_var.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "input_var.grad = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### An end-to-end example: A linear classifier in pure PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "input_dim = 2\n",
    "output_dim = 1\n",
    "\n",
    "W = torch.rand(input_dim, output_dim, requires_grad=True)\n",
    "b = torch.zeros(output_dim, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def model(inputs, W, b):\n",
    "    return torch.matmul(inputs, W) + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def mean_squared_error(targets, predictions):\n",
    "    per_sample_losses = torch.square(targets - predictions)\n",
    "    return torch.mean(per_sample_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "learning_rate = 0.1\n",
    "\n",
    "def training_step(inputs, targets, W, b):\n",
    "    predictions = model(inputs)\n",
    "    loss = mean_squared_error(targets, predictions)\n",
    "    loss.backward()\n",
    "    grad_loss_wrt_W, grad_loss_wrt_b = W.grad, b.grad\n",
    "    with torch.no_grad():\n",
    "        W -= grad_loss_wrt_W * learning_rate\n",
    "        b -= grad_loss_wrt_b * learning_rate\n",
    "    W.grad = None\n",
    "    b.grad = None\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Packaging state and computation with the Module class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class LinearModel(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.W = torch.nn.Parameter(torch.rand(input_dim, output_dim))\n",
    "        self.b = torch.nn.Parameter(torch.zeros(output_dim))\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        return torch.matmul(inputs, self.W) + self.b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model = LinearModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "torch_inputs = torch.tensor(inputs)\n",
    "output = model(torch_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def training_step(inputs, targets):\n",
    "    predictions = model(inputs)\n",
    "    loss = mean_squared_error(targets, predictions)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    model.zero_grad()\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Making PyTorch modules fast using compilation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "compiled_model = torch.compile(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "@torch.compile\n",
    "def dense(inputs, W, b):\n",
    "    return torch.nn.relu(torch.matmul(inputs, W) + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### What makes the PyTorch approach unique"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Introduction to JAX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### First steps with JAX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
     "##### Tensors in JAX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from jax import numpy as jnp\n",
    "jnp.ones(shape=(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "jnp.zeros(shape=(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "jnp.array([1, 2, 3], dtype=\"float32\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
     "##### Random number generation in JAX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "np.random.normal(size=(3,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "np.random.normal(size=(3,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def apply_noise(x, seed):\n",
    "    np.random.seed(seed)\n",
    "    x = x * np.random.normal((3,))\n",
    "    return x\n",
    "\n",
    "seed = 1337\n",
    "y = apply_noise(x, seed)\n",
    "seed += 1\n",
    "z = apply_noise(x, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "seed_key = jax.random.key(1337)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "seed_key = jax.random.key(0)\n",
    "jax.random.normal(seed_key, shape=(3,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "seed_key = jax.random.key(123)\n",
    "jax.random.normal(seed_key, shape=(3,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "jax.random.normal(seed_key, shape=(3,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "seed_key = jax.random.key(123)\n",
    "jax.random.normal(seed_key, shape=(3,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "new_seed_key = jax.random.split(seed_key, num=1)[0]\n",
    "jax.random.normal(new_seed_key, shape=(3,))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Tensor assignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x = jnp.array([1, 2, 3], dtype=\"float32\")\n",
    "new_x = x.at[0].set(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Tensor operations: Doing math in JAX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "a = jnp.ones((2, 2))\n",
    "b = jnp.square(a)\n",
    "c = jnp.sqrt(a)\n",
    "d = b + c\n",
    "e = jnp.matmul(a, b)\n",
    "e *= d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def dense(inputs, W, b):\n",
    "    return jax.nn.relu(jnp.matmul(inputs, W) + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Computing gradients with JAX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def compute_loss(input_var):\n",
    "    return jnp.square(input_var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "grad_fn = jax.grad(compute_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "input_var = jnp.array(3.0)\n",
    "grad_of_loss_wrt_input_var = grad_fn(input_var)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### JAX gradient-computation best practices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Returning the loss value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "grad_fn = jax.value_and_grad(compute_loss)\n",
    "output, grad_of_loss_wrt_input_var = grad_fn(input_var)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Getting gradients for a complex function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "###### Returning auxiliary outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Making JAX functions fast with @jax.jit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def dense(inputs, W, b):\n",
    "    return jax.nn.relu(jnp.matmul(inputs, W) + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### An end-to-end example: A linear classifier in pure JAX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def model(inputs, W, b):\n",
    "    return jnp.matmul(inputs, W) + b\n",
    "\n",
    "def mean_squared_error(targets, predictions):\n",
    "    per_sample_losses = jnp.square(targets - predictions)\n",
    "    return jnp.mean(per_sample_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def compute_loss(state, inputs, targets):\n",
    "    W, b = state\n",
    "    predictions = model(inputs, W, b)\n",
    "    loss = mean_squared_error(targets, predictions)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "grad_fn = jax.value_and_grad(compute_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "learning_rate = 0.1\n",
    "\n",
    "@jax.jit\n",
    "def training_step(inputs, targets, W, b):\n",
    "    loss, grads = grad_fn((W, b), inputs, targets)\n",
    "    grad_wrt_W, grad_wrt_b = grads\n",
    "    W = W - grad_wrt_W * learning_rate\n",
    "    b = b - grad_wrt_b * learning_rate\n",
    "    return loss, W, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "input_dim = 2\n",
    "output_dim = 1\n",
    "\n",
    "W = jax.numpy.array(np.random.uniform(size=(input_dim, output_dim)))\n",
    "b = jax.numpy.array(np.zeros(shape=(output_dim,)))\n",
    "state = (W, b)\n",
    "for step in range(40):\n",
    "    loss, W, b = training_step(inputs, targets, W, b)\n",
    "    print(f\"Loss at step {step}: {loss:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### What makes the JAX approach unique"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Introduction to Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### First steps with Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Picking a backend framework"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "\n",
    "import keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Layers: The building blocks of deep learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### The base `Layer` class in Keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "\n",
    "class SimpleDense(keras.Layer):\n",
    "    def __init__(self, units, activation=None):\n",
    "        super().__init__()\n",
    "        self.units = units\n",
    "        self.activation = activation\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        batch_dim, input_dim = input_shape\n",
    "        self.W = self.add_weight(\n",
    "            shape=(input_dim, self.units), initializer=\"random_normal\"\n",
    "        )\n",
    "        self.b = self.add_weight(shape=(self.units,), initializer=\"zeros\")\n",
    "\n",
    "    def call(self, inputs):\n",
    "        y = keras.ops.matmul(inputs, self.W) + self.b\n",
    "        if self.activation is not None:\n",
    "            y = self.activation(y)\n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "my_dense = SimpleDense(units=32, activation=keras.ops.relu)\n",
    "input_tensor = keras.ops.ones(shape=(2, 784))\n",
    "output_tensor = my_dense(input_tensor)\n",
    "print(output_tensor.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Automatic shape inference: Building layers on the fly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import layers\n",
    "\n",
    "layer = layers.Dense(32, activation=\"relu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import models\n",
    "from keras import layers\n",
    "\n",
    "model = models.Sequential(\n",
    "    [\n",
    "        layers.Dense(32, activation=\"relu\"),\n",
    "        layers.Dense(32),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# SimpleDense.call() applies `self.activation(y)` directly, so pass callables\n",
    "# (like `keras.ops.relu` in the `my_dense` example) rather than strings.\n",
    "model = keras.Sequential(\n",
    "    [\n",
    "        SimpleDense(32, activation=keras.ops.relu),\n",
    "        SimpleDense(64, activation=keras.ops.relu),\n",
    "        SimpleDense(32, activation=keras.ops.relu),\n",
    "        SimpleDense(10, activation=keras.ops.softmax),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### From layers to models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### The \"compile\" step: Configuring the learning process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Optimizer, loss, and metrics given by their string names.\n",
    "model = keras.Sequential([keras.layers.Dense(1)])\n",
    "model.compile(optimizer=\"rmsprop\", loss=\"mean_squared_error\", metrics=[\"accuracy\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Configure with object instances instead of string names.\n",
    "model.compile(\n",
    "    loss=keras.losses.MeanSquaredError(),\n",
    "    optimizer=keras.optimizers.RMSprop(),\n",
    "    metrics=[keras.metrics.BinaryAccuracy()],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Picking a loss function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Understanding the fit method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Assumes `inputs` and `targets` were created earlier in the notebook.\n",
    "history = model.fit(inputs, targets, epochs=5, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Loss and metric values recorded by fit(), one entry per epoch.\n",
    "history.history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Monitoring loss and metrics on validation data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model = keras.Sequential([keras.layers.Dense(1)])\n",
    "model.compile(\n",
    "    loss=keras.losses.MeanSquaredError(),\n",
    "    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),\n",
    "    metrics=[keras.metrics.BinaryAccuracy()],\n",
    ")\n",
    "\n",
    "# Shuffle first, in case `inputs` is stored in some order (e.g. by class),\n",
    "# so the held-out slice is a random sample.\n",
    "indices_permutation = np.random.permutation(len(inputs))\n",
    "shuffled_inputs, shuffled_targets = (\n",
    "    inputs[indices_permutation],\n",
    "    targets[indices_permutation],\n",
    ")\n",
    "\n",
    "# The first 30% of the shuffled samples is held out for validation.\n",
    "num_validation_samples = int(0.3 * len(inputs))\n",
    "val_inputs, training_inputs = (\n",
    "    shuffled_inputs[:num_validation_samples],\n",
    "    shuffled_inputs[num_validation_samples:],\n",
    ")\n",
    "val_targets, training_targets = (\n",
    "    shuffled_targets[:num_validation_samples],\n",
    "    shuffled_targets[num_validation_samples:],\n",
    ")\n",
    "model.fit(\n",
    "    training_inputs,\n",
    "    training_targets,\n",
    "    epochs=5,\n",
    "    batch_size=16,\n",
    "    validation_data=(val_inputs, val_targets),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Inference: Using a model after training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Run inference on the held-out samples in batches of 128.\n",
    "predictions = model.predict(x=val_inputs, batch_size=128)\n",
    "# Show only the first 10 predictions.\n",
    "print(predictions[:10])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "chapter03_introduction-to-ml-frameworks",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}